import copy
import torch.nn as  nn
import numpy as np

from allennlp.predictors import Predictor
from allennlp.data import Token

from tools.tokenizers import WordTokenizer, PretrainedTransformerTokenizer
from config import Config
from tools.utils import pos_tags, write_json

# Default value of ``Attacker.supported_postag`` (Penn Treebank tag set).
# Presumably the tags whose tokens an attacker may perturb -- confirm in the
# subclasses that read ``supported_postag``.
# Deliberately excluded: CD (cardinal number), FW (foreign word),
# LS (list item marker), POS (possessive ending), SYM (symbol) and
# UH (interjection).
# Entries marked [1] carry the author's "1" flag; its meaning is not
# documented here (possibly closed-class / function-word tags -- TODO confirm).
DEFAULT_SUPPORTED_POSTAG = [
    'CC',    # coordinating conjunction: "and", "but", "neither", "versus", "whether", "yet", "so"
    'DT',    # determiner: "all", "an", "both", "those"  [1]
    'EX',    # existential "there"  [1]
    'IN',    # preposition / subordinating conjunction: "among", "below", "into"  [1]
    'JJ',    # adjective: "second", "ill-mannered"
    'JJR',   # comparative adjective: "colder"
    'JJS',   # superlative adjective: "cheapest"
    'MD',    # modal: "can", "must", "shouldn't"  [1]
    'NN',    # noun, singular or mass
    'NNS',   # noun, plural
    'NNP',   # proper noun, singular
    'NNPS',  # proper noun, plural
    'PDT',   # predeterminer: "all", "both", "many"  [1]
    'PRP',   # personal pronoun: "hers", "herself", "ours", "they", "theirs"  [1]
    'PRP$',  # possessive pronoun: "hers", "his", "mine", "ours"  [1]
    'RB',    # adverb
    'RBR',   # comparative adverb: "lower", "heavier"
    'RBS',   # superlative adverb: "best", "biggest"
    'RP',    # particle: "board", "about", "across", "around"  [1]
    'TO',    # "to"  [1]
    'VB',    # verb, base form
    'VBD',   # verb, past tense
    'VBG',   # verb, gerund or present participle
    'VBN',   # verb, past participle
    'VBP',   # verb, non-3rd-person singular present
    'VBZ',   # verb, 3rd-person singular present
    'WDT',   # wh-determiner: "that", "what", "whatever", "which", "whichever"  [1]
    'WP',    # wh-pronoun: "that", "who"  [1]
    'WP$',   # possessive wh-pronoun: "whose"  [1]
    'WRB',   # wh-adverb: "however", "wherever", "whenever"  [1]
]

# Default value of ``Attacker.ignored_token``: a null placeholder token,
# sentence punctuation, and BERT-style special tokens.
DEFAULT_IGNORED_TOKEN = [
    "@@NULL@@",
    ".", ",", ";", "!", "?",
    "[MASK]", "[SEP]", "[CLS]",
]


class Attacker(nn.Module):
    """Base class for adversarial attackers.

    Holds the experiment config and the victim predictor and provides shared
    helpers: the attack budget (``attack_num``), the success/stop criterion
    (``stop``) and result packaging (``attack_result``). Subclasses must
    implement ``forward``.
    """

    def __init__(self, cf: Config, predictor: Predictor):
        super(Attacker, self).__init__()
        self.cf = cf
        self.predictor = predictor

        # Criterion used by ``stop`` when none is passed explicitly:
        # either 'threshold' or 'flip'.
        self.stop_condition = 'threshold'

        # Kind of input this attacker consumes ('instance' here; the
        # black-box subclass switches it to 'text').
        self.instance_or_text = 'instance'

        # Copy the module-level defaults so that per-instance edits do not
        # leak into the shared lists (and thus into every other attacker).
        self.supported_postag = list(DEFAULT_SUPPORTED_POSTAG)
        self.ignored_token = list(DEFAULT_IGNORED_TOKEN)

    def forward(self, *inputs):
        """Run the attack; must be implemented by subclasses."""
        raise NotImplementedError

    def attack_num(self, len_instance):
        """Return how many positions may be perturbed in an input of length ``len_instance``.

        ``cf.attack_ratio_or_num[cf.attacker]`` is read as a ratio of the input
        length when it is <= 1 (the product is truncated with ``int``), and as
        an absolute count otherwise (returned unchanged).
        """
        ratio_or_num = self.cf.attack_ratio_or_num[self.cf.attacker]
        if ratio_or_num <= 1:
            return int(ratio_or_num * len_instance)
        return ratio_or_num

    def stop(self, outputs, stop_condition=None):
        """Return, for each prediction output, whether the attack goal is met.

        Args:
            outputs: list of predictor output dicts. 'flip' reads the ``'pred'``
                and ``'gold'`` keys; 'threshold' reads ``'gold_prob'`` and the
                length of ``'logits'`` (taken from the first output).
            stop_condition: 'flip' (prediction differs from gold) or
                'threshold' (gold-class probability below ``0.8 / num_classes``).
                Defaults to ``self.stop_condition``.

        Returns:
            A list of bools, one per output (empty for empty ``outputs``).

        Raises:
            ValueError: if the stop condition is not recognised.
        """
        if stop_condition is None:
            stop_condition = self.stop_condition
        if stop_condition == 'flip':
            return [bool(abs(o['pred'] - o['gold']) > 1e-5) for o in outputs]
        if stop_condition == 'threshold':
            # The class count comes from outputs[0], so guard the empty case
            # to match the 'flip' branch instead of raising IndexError.
            if not outputs:
                return []
            class_num = len(outputs[0]['logits'])
            threshold = 0.8 / class_num
            return [bool(o['gold_prob'] < threshold) for o in outputs]
        # Report the condition actually being checked (which may be the
        # argument, not self.stop_condition).
        raise ValueError(f'stop condition: {stop_condition} not implemented.')

    def attack_result(self, success=None,
                      length=None,
                      adv_example=None):
        """Package an attack outcome as a dict with keys 'success', 'length', 'adv_example'."""
        return {
            'success': success,
            'length': length,
            'adv_example': adv_example
        }


class BlackBoxAttacker(Attacker):
    """Attacker that only queries the predictor with (detokenized) text.

    Subclasses implement ``get_victim_substitute_pair`` to choose which token
    positions to replace and with what. ``forward`` applies those substitutions
    cumulatively, scores every intermediate candidate in one batch, and reports
    the first candidate that satisfies ``stop``.
    """

    def __init__(self, cf: Config, predictor: Predictor):
        super(BlackBoxAttacker, self).__init__(cf, predictor)
        self.instance_or_text = 'text'

        self.tokenizer = WordTokenizer(pos_tags=True)

    def forward(self, text):
        """Attack a single example.

        Args:
            text: dict with at least a raw string ``'sentence'`` and a
                ``'label'``. The caller's dict is not modified.

        Returns:
            The ``attack_result`` dict:
            ``success`` -- the prediction on the chosen candidate differs from
            the gold label reported for the unmodified input;
            ``length`` -- number of substitution steps applied, divided by the
            token count (0.0 for an empty sentence);
            ``adv_example`` -- the chosen candidate with its sentence
            detokenized and its ``'tag'`` entry removed.
        """
        # Shallow copy so the caller's dict keeps its raw sentence and does
        # not gain a 'tag' entry.
        text = dict(text)
        text['sentence'] = [t.text for t in self.tokenizer.tokenize(text['sentence'])]
        text['tag'] = pos_tags(list(text['sentence']))

        pairs = self.get_victim_substitute_pair(text)

        # candidates[k] has the first k substitutions applied (cumulatively).
        candidates = [self.copy_text(text)]
        for position, word in pairs:
            candidates.append(self.subsitude(candidates[-1], position, word))

        outputs = self.predict_batch_data(candidates)

        # First candidate meeting the stop condition; otherwise fall back to
        # the fully perturbed one. Cast to a plain int for indexing/reporting.
        stop_idx = np.nonzero(self.stop(outputs))[0]
        length = int(stop_idx[0]) if len(stop_idx) > 0 else len(outputs) - 1

        adv_text = candidates[length]
        adv_text['sentence'] = self.tokenizer.detokenize(adv_text['sentence'])
        del adv_text['tag']

        gold = outputs[0]['gold']
        pred = outputs[length]['pred']
        success = bool(abs(gold - pred) > 1e-3)

        # Guard against ZeroDivisionError for an empty tokenized sentence.
        num_tokens = len(text['sentence'])
        ratio = length / num_tokens if num_tokens else 0.0

        return self.attack_result(success=success,
                                  length=ratio,
                                  adv_example=adv_text)

    def get_victim_substitute_pair(self, *args, **kwargs):
        """Return an iterable of ``(position, word)`` substitutions.

        ``forward`` calls this positionally with the tokenized, POS-tagged text
        dict. A ``word`` of ``None`` leaves that position unchanged (see
        ``subsitude``). Must be overridden by subclasses; accepting both
        positional and keyword arguments keeps the base raising
        NotImplementedError rather than TypeError.
        """
        raise NotImplementedError

    def predict_batch_data(self, texts):
        """Detokenize ``texts`` and score them with ``predictor.predict_many_json``."""
        texts = self.detokenize_texts(texts)
        outputs = self.predictor.predict_many_json(texts)
        return outputs

    def detokenize_texts(self, texts):
        """Return copies of ``texts`` whose token-list sentences are joined back into strings."""
        new_texts = [self.copy_text(t) for t in texts]
        for text in new_texts:
            text['sentence'] = self.tokenizer.detokenize(text['sentence'])
        return new_texts

    def subsitude(self, instance, i, word):
        """Return a copy of ``instance`` with token ``i`` replaced by ``word``.

        If ``word`` is None the copy is returned unchanged. (The misspelled
        name is kept for compatibility with existing callers.)
        """
        new_instance = self.copy_text(instance)
        if word is not None:
            new_instance['sentence'][i] = word
        return new_instance

    def copy_text(self, text):
        """Copy the 'sentence', 'tag' and 'label' fields of ``text``.

        The token and tag lists are shallow-copied so that candidates can be
        edited independently; any other keys are dropped.
        """
        return {
            'sentence': list(text['sentence']),
            'tag': list(text['tag']),
            'label': text['label'],
        }
